import time
import os
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn.init as init
import torch.nn.functional as F

from data.dataset import hierarchical_dataset
from il_modules.base import BaseLearner
from modules.model import MRNNet2
from test import validation, validation_common
from tools.utils import Averager, adjust_learning_rate

# Small constant (presumably a numerical-stability epsilon); not referenced in
# this section of the file.
EPSILON = 1e-8

# Presumably hyperparameters for the initial (first-task) stage. These module
# constants are not read in this section: training below is configured through
# `opt` (opt.num_iter, opt.lr, opt.schedule, ...).
init_epoch = 200
init_lr = 0.1
init_milestones = [60, 120, 170]
init_lr_decay = 0.1
init_weight_decay = 0.0005

# Presumably hyperparameters for later (incremental) stages; likewise not read
# in this section.
epochs = 170
lrate = 0.1
milestones = [80, 120, 150]
lrate_decay = 0.1
batch_size = 128  # NOTE: the training loops below use a local `batch_size` taken from the image batch
weight_decay = 2e-4
num_workers = 8
T = 2  # presumably a distillation temperature — unused here, TODO confirm

# Device onto which the MRN label-conversion helpers (get_chars_* / get_words_*)
# move their target tensors. It is separate from `self.device`
# (presumably both resolve to the same device — TODO confirm).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class MRN(BaseLearner):

    def __init__(self, opt):
        """Create the MRN learner and derive sequence-length settings from `opt`.

        Sets:
            self.model: a fresh `MRNNet2(opt)`.
            self.patch: CTC output length for the configured backbone
                (only set for "VGG", "SVTR" and "ResNet").
            self.num_steps: number of prediction steps per image — `self.patch`
                for CTC, otherwise `opt.batch_max_length + 1`.
            self.start_taski: task index parsed from the checkpoint file name
                (the part before the first "_"), or 0 without a checkpoint.
        """
        super().__init__(opt)
        self.model = MRNNet2(opt)

        # Output length per backbone; an unknown backbone leaves `self.patch` unset.
        patch_by_extractor = {"VGG": 63, "SVTR": 64, "ResNet": 65}
        extractor = self.opt.FeatureExtraction
        if extractor in patch_by_extractor:
            self.patch = patch_by_extractor[extractor]

        if self.opt.Prediction != "CTC":
            self.num_steps = self.opt.batch_max_length + 1
        else:
            self.num_steps = self.patch

        if opt.checkpoint is None:
            self.start_taski = 0
        else:
            # Checkpoint files are named "<taski>_...".
            checkpoint_file = opt.checkpoint.split('/')[-1]
            self.start_taski = int(checkpoint_file.split('_')[0])

    def incremental_train(self, taski, character, train_loader, valid_loader, chars_each_task):
        """Train the learner on task `taski`.

        Rebuilds the label converter for the current character set and builds
        (task 0) or extends (later tasks) the model. For later tasks it freezes
        branches 0..taski-1 and sets up the optimizer over the remaining
        trainable parameters. It then runs step 0 via `_train` and, for
        tasks > 0, step 1 as well.
        """
        self.character = character
        self.converter = self.build_converter()
        self._total_classes = len(self.converter.character)
        # Criteria for the task-index prediction heads.
        self.taski_criterion = torch.nn.CrossEntropyLoss(reduction="mean").to(self.device)
        self.taski_criterion_multi = torch.nn.MultiLabelMarginLoss(reduction="mean").to(self.device)

        if taski == 0:
            self.build_model()
            self.criterion = self.build_criterion()
        else:
            self.change_model()
            # Freeze the branches of all previous tasks (0 .. taski-1).
            for previous in range(taski):
                for param in self.model.module.model[previous].parameters():
                    param.requires_grad = False

        # Only parameters that still require gradients are optimised.
        self.build_optimizer(self.count_param())

        self._train(0, taski, train_loader, valid_loader, chars_each_task, step=0)
        if taski > 0:
            self._train(0, taski, train_loader, valid_loader, chars_each_task, step=1)

    def build_model(self):
        """Attach the FC/prediction heads, initialise weights and wrap in DataParallel.

        Every parameter except those whose name contains "localization_fc2"
        is re-initialised:
        - biases are set to 0;
        - weights get Kaiming-normal init;
        - if that raises (e.g. 1-D weights such as BatchNorm scales), the
          weight is filled with 1 instead.
        """
        self.model.build_fc(self.opt.hidden_size, self._total_classes)
        self.model.build_prediction(self.opt, self._total_classes)

        for name, param in self.model.named_parameters():
            if "localization_fc2" in name:
                print(f"Skip {name} as it is already initialized")
                continue
            is_bias = "bias" in name
            is_weight = "weight" in name
            try:
                if is_bias:
                    init.constant_(param, 0.0)
                elif is_weight:
                    init.kaiming_normal_(param)
            except Exception:
                # Kaiming init rejects tensors with fewer than 2 dims (e.g. BatchNorm weights).
                if is_weight:
                    param.data.fill_(1)

        # Multi-GPU wrapper; the rest of the class reaches the net via `.module`.
        self.model = torch.nn.DataParallel(self.model).to(self.device)
        self.model.train()

    def _train(self, start_iter, taski, train_loader, valid_loader, chars_each_task, step=0):
        """Run, or resume past, one step of task `taski`.

        Step 0 of task t sits at position t and step 1 at t + 0.5. If
        `opt.start_task` lies beyond that position, the saved best checkpoint
        for this (task, step) is loaded and only the training data is
        re-prepared. Otherwise:
        - task 0 trains with `_init_train`;
        - step 0 of later tasks trains with `_init_train` on the new task's
          data, then freezes the last branch and puts it in eval mode;
        - step 1 prepares (rehearsal) data and runs `_update_representation`.
        """
        if self.opt.start_task > taski + step * 0.5:
            name = self.opt.lan_list[taski]
            checkpoint_path = f"./saved_models/{self.opt.exp_name}/{taski}_{name}_{step}_best_score.pth"
            self.model.load_state_dict(torch.load(checkpoint_path), strict=True)
            print('Task {} load checkpoint from {}.'.format(taski, checkpoint_path))
            if taski > 0:
                if step == 0:
                    train_loader.get_dataset(taski, memory=None)
                elif step == 1:
                    if self.opt.memory is not None:
                        self.build_rehearsal_memory(train_loader, taski)
                    else:
                        train_loader.get_dataset(taski, memory=self.opt.memory)
            return

        if taski == 0:
            print('[Task{}] start training for model ------{}------'.format(taski, self.opt.exp_name))
            self.write_log('\n[Task{}] start training for model -------{}------'.format(taski, self.opt.exp_name))
            self._init_train(start_iter, taski, train_loader, valid_loader.create_dataset(), cross=False)
            return

        if step == 0:
            print('[Task{}] start training for model ------{}------'.format(taski, self.opt.exp_name))
            self.write_log('\n[Task{}] start training for model -------{}------\n'.format(taski, self.opt.exp_name))
            # Task 0 gets its dataset elsewhere; later tasks load it here.
            train_loader.get_dataset(taski, memory=None)
            self._init_train(start_iter, taski, train_loader, valid_loader.create_dataset(), cross=False)
            # Freeze the branch that was just trained.
            current_branch = self.model.module.model[-1]
            for param in current_branch.parameters():
                param.requires_grad = False
            current_branch.eval()
            return

        if self.opt.memory is not None:
            self.build_rehearsal_memory(train_loader, taski)
        else:
            train_loader.get_dataset(taski, memory=self.opt.memory)
        self._update_representation(start_iter, taski, train_loader, valid_loader.create_list_dataset(), chars_each_task)

    def _init_train(self, start_iter, taski, train_loader, valid_loader, cross=False):
        """Train with the recognition loss only, for iterations start_iter+1 .. opt.num_iter.

        Used for task 0 and for step 0 of later tasks. Validation (which also
        keeps the best checkpoint) runs every `opt.val_interval` iterations
        and after the final iteration.

        Args:
            start_iter: iteration to resume after (0 for a fresh run).
            taski: index of the current task.
            train_loader: object whose `get_batch()` returns (images, labels).
            valid_loader: validation data forwarded to `self.val`.
            cross: forwarded to the model's forward call.
        """
        train_loss_avg = Averager()
        start_time = time.time()
        self.best_score = -1

        iterations = range(start_iter + 1, self.opt.num_iter + 1)
        # `total` follows the range actually iterated so the progress bar is
        # correct when resuming with start_iter > 0.
        for iteration in tqdm(iterations, total=len(iterations), position=0, leave=True):
            image_tensors, labels = train_loader.get_batch()  # get_dataset -> create_dataset -> get_batch
            image = image_tensors.to(self.device)
            labels_index, labels_length = self.converter.encode(labels, batch_max_length=self.opt.batch_max_length)
            batch_size = image.size(0)

            if "CTC" in self.opt.Prediction:
                preds = self.model(image, cross)['logits']
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                # B, T, C -> T, B, C
                preds_log_softmax = preds.log_softmax(2).permute(1, 0, 2)
                loss = self.criterion(preds_log_softmax, labels_index, preds_size, labels_length)
            else:
                # The decoder receives labels_index[:, :-1] and is scored
                # against labels_index[:, 1:] (the labels without [SOS]).
                preds = self.model(image, cross, labels_index[:, :-1])['logits']
                target = labels_index[:, 1:]
                loss = self.criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

            self.model.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.opt.grad_clip)
            self.optimizer.step()
            train_loss_avg.add(loss)

            if "super" in self.opt.schedule:
                self.scheduler.step()
            else:
                adjust_learning_rate(self.optimizer, iteration, self.opt)

            # Validate every `val_interval` iterations and after the last one.
            if iteration % self.opt.val_interval == 0 or iteration == self.opt.num_iter:
                self.val(valid_loader, self.opt, start_time, iteration, train_loss_avg, None, taski, 0, "FF")
                train_loss_avg.reset()

    def after_task(self):
        """Unwrap the DataParallel model and keep a frozen snapshot as the old network.

        Records `_known_classes` from `_total_classes`.

        NOTE(review): relies on the unwrapped model providing `copy()` and
        `freeze()`, which are not defined in this file — confirm on MRNNet2.
        """
        unwrapped = self.model.module
        self.model = unwrapped
        self._known_classes = self._total_classes
        self._old_network = unwrapped.copy().freeze()

    def model_eval_and_train(self, taski):
        """Set training mode on the model and its last branch, and eval mode on branches 0..taski-1."""
        wrapped = self.model
        wrapped.train()
        branches = wrapped.module.model
        branches[-1].train()
        if taski < 1:
            return
        for idx in range(taski):
            branches[idx].eval()

    def build_custom_optimizer(self, filtered_parameters, optimizer="adam", schedule="super", scale=1.0, the=2):
        """Create `self.optimizer` and, for a recognised schedule, `self.scheduler`.

        Args:
            filtered_parameters: parameters to optimise (e.g. from `count_param`).
            optimizer: optimizer name — "sgd", "adadelta" or "adam".
            schedule: a string containing "super" selects OneCycleLR; "mlr"
                selects MultiStepLR (opt.milestones / opt.lrate_decay); any
                other value leaves `self.scheduler` unchanged.
            scale: multiplier applied to `opt.lr`.
            the: the OneCycleLR schedule spans `opt.num_iter * the` steps.

        Raises:
            ValueError: for an unsupported optimizer name (previously the name
                string itself was silently stored as `self.optimizer`).
        """
        optimizer_name = optimizer
        lr = self.opt.lr * scale
        if optimizer_name == "sgd":
            new_optimizer = torch.optim.SGD(
                filtered_parameters,
                lr=lr,
                momentum=self.opt.sgd_momentum,
                weight_decay=self.opt.sgd_weight_decay,
            )
        elif optimizer_name == "adadelta":
            new_optimizer = torch.optim.Adadelta(
                filtered_parameters, lr=lr, rho=self.opt.rho, eps=self.opt.eps
            )
        elif optimizer_name == "adam":
            new_optimizer = torch.optim.Adam(filtered_parameters, lr=lr)
        else:
            raise ValueError(f"Unsupported optimizer: {optimizer_name!r}")
        self.optimizer = new_optimizer

        if "super" in schedule:
            # Compare the *name*: the parameter used to be rebound to the
            # optimizer object, so the "sgd" check was always False and SGD
            # never cycled momentum.
            self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
                new_optimizer,
                max_lr=lr,
                cycle_momentum=optimizer_name == "sgd",
                div_factor=20,
                final_div_factor=1000,
                total_steps=self.opt.num_iter * the,
            )
        elif schedule == "mlr":
            self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer=new_optimizer, milestones=self.opt.milestones, gamma=self.opt.lrate_decay
            )

    def change_model(self):
        """Update the FC layer and rebuild the prediction head for `self._total_classes`, then re-wrap for multi-GPU.

        The model is unwrapped from DataParallel first (if wrapped) and is
        left wrapped and in training mode.
        """
        wrapped = self.model
        if isinstance(wrapped, torch.nn.DataParallel):
            self.model = wrapped.module
        net = self.model
        net.update_fc(self.opt.hidden_size, self._total_classes)
        net.build_prediction(self.opt, self._total_classes)
        self.model = torch.nn.DataParallel(net).to(self.device)
        self.model.train()

    def build_rehearsal_memory(self, train_loader, taski):
        """Sample rehearsal exemplars for task `taski` and reload `train_loader` with them.

        If `opt.memory_num` >= 5000, `memory_num` samples are requested for
        the task. Otherwise the request is `int(memory_num / taski)`, so
        `taski` must be >= 1. In that case, if the stored memory then exceeds
        `memory_num`, `reduce_samplers` trims it.

        NOTE: when resuming from a pre-trained checkpoint, `self.memory_index`
        for earlier tasks is not rebuilt here.
        """
        memory_num = self.opt.memory_num
        num_i = memory_num if memory_num >= 5000 else int(memory_num / taski)

        self.build_random_current_memory(num_i, taski, train_loader)

        if memory_num < 5000:
            stored = self.memory_index
            if len(stored) != 0 and len(stored) * len(stored[0]) > memory_num:
                self.reduce_samplers(taski, taski_num=num_i)  # only triggers once taski > 1
        train_loader.get_dataset(taski, memory=self.opt.memory, index_list=self.memory_index)

        # Report the memory that is actually stored. The right-hand side used
        # to print `opt.memory_num`, which misreports the total when
        # memory_num >= 5000 (no trimming happens then). Indexing [0] also
        # crashed on an empty memory.
        groups = len(self.memory_index)
        per_group = len(self.memory_index[0]) if groups else 0
        total = sum(len(idx) for idx in self.memory_index)
        print('[Task{}]: Rehearsal memory: {} * {} = {}'.format(taski, groups, per_group, total))

    def load_stored_lan(self, taski, labels, root='/data4/liuxiaoqian/Projects/MRN/data/common'):
        """Look up, for every label, its stored "first language" string, with spaces removed.

        Reads `<root>/<taski>_<opt.lan_list[taski]>_CommonLabels.txt`, a
        tab-separated file. The second column is the label and the last column
        is the stored string. The alignment criteria in this class convert each
        of its characters with `int()` to select a branch.

        Args:
            taski: task index; also selects the language name from opt.lan_list.
            labels: iterable of label strings to look up.
            root: directory containing the *_CommonLabels.txt files (defaults
                to the previously hard-coded location).

        Returns:
            list[list[str]]: for each label, the stored characters without spaces.

        Raises:
            ValueError: if a label does not occur in the file.
        """
        filename = f"{taski}_{self.opt.lan_list[taski]}_CommonLabels.txt"
        path = os.path.join(root, filename)

        # label -> stored string; the first occurrence wins, matching list.index().
        first_lan_by_label = {}
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                fields = line.split('\t')
                first_lan_by_label.setdefault(fields[1], fields[-1].strip('\n').strip(' '))

        labels_first_lan = []
        for label in labels:
            try:
                first_lan = first_lan_by_label[label]
            except KeyError:
                raise ValueError(f"{label!r} not found in {path}") from None
            # `==`, not `is`: identity comparison with a str literal is unreliable.
            labels_first_lan.append([ch for ch in first_lan if ch != ' '])
        return labels_first_lan

    def criterion_fea_alignment(self, labels_first_lan, features):
        """Mean L2 distance between the last branch's features and each character's reference branch.

        For sample i and character position j, the reference branch is
        `int(labels_first_lan[i][j])`. `features[ref][i][j]` is compared with
        `features[-1][i][j]` via `F.pairwise_distance(..., 2)`, and the
        distances are averaged.
        """
        reference_rows = []
        current_rows = []
        for i, sample_lans in enumerate(labels_first_lan):
            for j, lan in enumerate(sample_lans):
                reference_rows.append(features[int(lan)][i][j])
                current_rows.append(features[-1][i][j])
        reference = torch.stack(reference_rows)
        current = torch.stack(current_rows)
        return F.pairwise_distance(reference, current, 2).mean(dim=0)

    def criterion_preds_alignment(self, labels_first_lan, preds):
        '''
        KL divergence pulling the last branch's predictions towards each character's reference branch.

        labels_first_lan: list of per-sample sequences; int(labels_first_lan[i][j])
            selects the reference branch for character j of sample i.
        preds: indexed as preds[branch][i][j] -> class logits (the original note
            gives [D, B, 63, C]).
        Returns KLDivLoss(batchmean)(log_softmax(last), softmax(reference)).
        '''
        reference_rows = []
        current_rows = []
        for i, sample_lans in enumerate(labels_first_lan):
            for j, lan in enumerate(sample_lans):
                reference_rows.append(preds[int(lan)][i][j])
                current_rows.append(preds[-1][i][j])
        reference = torch.stack(reference_rows)  # [*, C]
        current = torch.stack(current_rows)  # [*, C]

        kl = torch.nn.KLDivLoss(reduction='batchmean')
        return kl(torch.log_softmax(current, dim=-1), torch.softmax(reference, dim=-1))

    def get_chars_multiLabel_indexs(self, taski, labels, chars_each_task):
        '''
        Build per-character MultiLabelMarginLoss targets.

        Each character becomes the list of task indices whose character set
        contains it, padded with -1 to taski + 1 entries. Each label is then
        padded with all -1 rows up to `self.num_steps` positions (longer labels
        are not truncated).
        return: labels_indexs: [B, self.num_steps, taski + 1] on `device`
        '''
        width = taski + 1
        samples = []
        for label in labels:
            rows = []
            for ch in label:
                owners = [t for t in range(len(chars_each_task)) if ch in chars_each_task[t]]
                if len(owners) < width:
                    owners += [-1] * (width - len(owners))
                rows.append(torch.tensor(owners))
            while len(rows) < self.num_steps:
                rows.append(torch.tensor([-1] * width))
            samples.append(torch.stack(rows))
        return torch.stack(samples).to(device)

    def get_chars_initial_indexs(self, current_indexs):
        """Repeat each sample's task index across all `self.num_steps` positions.

        Returns a [len(current_indexs), self.num_steps] tensor on `device`.
        """
        rows = [torch.tensor([task_idx] * self.num_steps) for task_idx in current_indexs]
        return torch.stack(rows).to(device)

    def get_chars_mixed_indexs(self, taski, labels, chars_each_task):
        '''
        Split character positions into single-task ("standard") and multi-label targets.

        - A character that belongs to exactly one task's character set becomes
          a CrossEntropyLoss target (that task index).
        - Every other character (shared by several tasks, or by none) becomes a
          MultiLabelMarginLoss target: its owning task indices padded with -1
          to taski + 1 entries.
        - Positions after the end of each label, up to `self.num_steps`, are
          added as all -1 multi-label targets so they line up with the
          [B, num_steps, D] predictions.

        Returns:
            standard_labels: LongTensor [N_std] on `device` (empty-safe).
            standard_index: list of (sample, position) pairs for standard_labels.
            multi_labels: tensor [N_multi, taski + 1] on `device`.
            multi_index: list of (sample, position) pairs for multi_labels.
        '''
        width = taski + 1
        standard_labels = []
        standard_index = []
        multi_labels = []
        multi_index = []

        for i, label in enumerate(labels):
            for j, ch in enumerate(label):
                owners = [m for m in range(len(chars_each_task)) if ch in chars_each_task[m]]
                if len(owners) == 1:
                    standard_labels.extend(owners)
                    standard_index.append((i, j))
                else:
                    if len(owners) < width:
                        owners.extend([-1] * (width - len(owners)))
                    multi_labels.append(torch.tensor(owners))
                    multi_index.append((i, j))
            # Padding positions come from the label length, not from the loop
            # variable `j`: for an empty label `j` is undefined, or stale from
            # the previous sample, which produced wrong/out-of-range positions.
            for pos in range(len(label), self.num_steps):
                multi_labels.append(torch.tensor([-1] * width))
                multi_index.append((i, pos))

        # Explicit dtype so an empty list still yields class indices (int64),
        # not a float tensor.
        standard_labels = torch.tensor(standard_labels, dtype=torch.long)
        multi_labels = torch.stack(multi_labels)
        return standard_labels.to(device), standard_index, multi_labels.to(device), multi_index

    def get_words_mixed_indexs(self, taski, labels, current_indexs, chars_each_task):
        '''
        Convert each word (label) to a target for the word-level task-index head.

        For sample i, `index = current_indexs[i]` is its own task index. The
        word becomes:
        - a standard target (CrossEntropyLoss) equal to `index`, when some
          character of the word occurs in no other task's character set, or
          when no other task's set contains every character of the word;
        - otherwise a multi-label target (MultiLabelMarginLoss): every other
          task whose set contains all characters of the word, followed by
          `index`, padded with -1 to taski + 1 entries.

        Returns:
            (standard targets on `device`, sample ids of standard targets,
             multi-label targets on `device`, sample ids of multi-label targets)

        NOTE(review): current_indexs is presumably the 1-D LongTensor built in
        `_update_representation`, so `index` is a 0-d tensor — confirm.
        NOTE(review): an empty label is assigned to neither group, which trips
        the final assert.
        '''
        id_standard = []
        id_multi = []
        index_labels_standard = []
        index_labels_multi = []
        for i in range(len(labels)):
            index = current_indexs[i]
            label_index = []  # for each char: the OTHER tasks (m != index) containing it
            multi_index = []   # other tasks containing every char of this label
            for j in range(len(labels[i])):
                char_index = []
                for m in range(len(chars_each_task)):
                    if labels[i][j] in chars_each_task[m] and m != index:
                        char_index.append(m)
                if len(char_index) == 0:
                    # This char is unique to the sample's own task -> standard target.
                    index_labels_standard.append(index)
                    id_standard.append(i)
                    break
                else:
                    label_index.append(char_index)
            # `len(char_index) > 0` here means the loop above did not `break`.
            if len(label_index) > 0 and len(char_index) > 0:
                # Keep each candidate from the first char that every other char also has.
                for n in range(len(label_index[0])):
                    id = label_index[0][n]
                    flag = True
                    for char_index in label_index[1:]:
                        if id not in char_index:
                            # Last candidate rejected and none accepted -> standard target.
                            if n == len(label_index[0])-1 and len(multi_index) == 0:
                                index_labels_standard.append(index)
                                id_standard.append(i)
                            flag = False
                            break
                    if flag is True:
                        multi_index.append(id)
                if len(multi_index) > 0:
                    # Append the sample's own task, then pad with -1 to taski + 1 entries.
                    multi_index.append(index)
                    if len(multi_index) < taski + 1:
                        multi_index.append(-1)
                        if len(multi_index) < taski + 1:
                            multi_index.extend([-1] * (taski + 1 - len(multi_index)))
                    multi_index = torch.tensor(multi_index)
                    index_labels_multi.append(multi_index)
                    id_multi.append(i)

        # torch.stack raises on an empty list; fall back to an empty tensor.
        try:
            index_labels_standard = torch.stack(index_labels_standard)
        except:
            index_labels_standard = torch.tensor(index_labels_standard)  # for empty
        try:
            index_labels_multi = torch.stack(index_labels_multi)
        except:
            index_labels_multi = torch.tensor(index_labels_multi)
        # Every sample must land in exactly one of the two groups.
        assert len(labels) == len(index_labels_standard) + len(index_labels_multi)
        return index_labels_standard.to(device), id_standard, index_labels_multi.to(device), id_multi

    def get_taski_loss_chars(self, preds_all, index_standard, id_standard, index_multi, id_multi):
        '''
        Character-level task-index loss.

        preds_all: [B, T, D] task logits per character position.
        index_standard / id_standard: CrossEntropyLoss targets and their
            (sample, position) pairs.
        index_multi / id_multi: MultiLabelMarginLoss targets and their
            (sample, position) pairs.

        Returns the sum of both losses. An empty group contributes 0 instead of
        crashing in torch.stack, as get_taski_loss_words already does.
        '''
        if len(id_standard) > 0:
            preds_standard = torch.stack([preds_all[i][j] for i, j in id_standard])
            loss_standard = self.taski_criterion(preds_standard, index_standard)
        else:
            loss_standard = preds_all.new_zeros(())

        if len(id_multi) > 0:
            preds_multi = torch.stack([preds_all[i][j] for i, j in id_multi])
            loss_multi = self.taski_criterion_multi(preds_multi, index_multi)
        else:
            loss_multi = preds_all.new_zeros(())

        return loss_standard + loss_multi

    def get_taski_loss_words(self, preds_all, index_standard, id_standard, index_multi, id_multi):
        '''
        Word-level task-index loss.

        preds_all: [B, D] task logits per sample.
        Rows listed in id_standard are scored with CrossEntropyLoss against
        index_standard. Rows in id_multi are scored with MultiLabelMarginLoss
        against index_multi. An empty group contributes torch.tensor(0.).
        '''
        if len(id_standard) > 0:
            picked = torch.stack([preds_all[sample] for sample in id_standard])
            loss_standard = self.taski_criterion(picked, index_standard)  # cross_entropy
        else:
            loss_standard = torch.tensor(0.)

        if len(id_multi) > 0:
            picked = torch.stack([preds_all[sample] for sample in id_multi])
            loss_multi = self.taski_criterion_multi(picked, index_multi)  # multiLabelMarginLoss()
        else:
            loss_multi = torch.tensor(0.)

        return loss_standard + loss_multi

    def _update_representation(self, start_iter, taski, train_loader, valid_loader, chars_each_task, pi=15):  # result of pi=15 > result of pi=1
        """Step 1 of task `taski`: joint recognition + task-index training.

        Rebuilds the criterion and an Adam optimizer with a OneCycle schedule,
        then runs iterations start_iter+1 .. 2 * opt.num_iter. Each iteration
        minimises `pi * loss_clf + taski_loss`, where:
        - loss_clf is the mean recognition loss of the word-level ("logits")
          and char-level ("logits_chars") outputs;
        - taski_loss is the mean of the word- and char-level task-index losses
          (CrossEntropy for unambiguous targets, MultiLabelMargin otherwise).
        Validation runs at iteration 1, every `opt.val_interval` iterations,
        and at the last iteration.

        Args:
            start_iter: iteration to resume after.
            taski: index of the current task.
            train_loader: object whose `get_batch2()` returns (images, labels, per-word task indices).
            valid_loader: validation data forwarded to `self.val`.
            chars_each_task: per-task character collections used to build the task targets.
            pi: weight of the recognition loss.
        """
        train_loss_avg = Averager()
        train_taski_loss_avg = Averager()

        self.criterion = self.build_criterion()
        filtered_parameters = self.count_param()
        # OneCycle spans num_iter * the = 2 * num_iter steps, one per iteration below.
        self.build_custom_optimizer(filtered_parameters, optimizer="adam", schedule="super", scale=1, the=2)

        start_time = time.time()
        self.best_score = -1

        last_iter = self.opt.num_iter * 2
        iterations = range(start_iter + 1, last_iter + 1)
        # `total` previously said opt.num_iter, so the bar overran to 200%.
        for iteration in tqdm(iterations, total=len(iterations), position=0, leave=True):
            image_tensors, labels, indexs = train_loader.get_batch2()  # indexs: per-word task index
            # reshape(-1) instead of squeeze(): a batch of one must stay 1-D
            # because the target builders index it per sample.
            indexs = torch.LongTensor(indexs).reshape(-1).to(self.device)
            image = image_tensors.to(self.device)
            labels_index, labels_length = self.converter.encode(labels, batch_max_length=self.opt.batch_max_length)
            batch_size = image.size(0)

            # Standard / multi-label task targets at word and char level.
            index_labels_standard_words, id_standard_words, index_labels_multi_words, id_multi_words = self.get_words_mixed_indexs(taski, labels, indexs, chars_each_task)
            index_labels_standard_chars, id_standard_chars, index_labels_multi_chars, id_multi_chars = self.get_chars_mixed_indexs(taski, labels, chars_each_task)

            if "CTC" in self.opt.Prediction:
                output = self.model(image, True)
                taski_loss_words = self.get_taski_loss_words(output["index"], index_labels_standard_words, id_standard_words, index_labels_multi_words, id_multi_words)
                taski_loss_chars = self.get_taski_loss_chars(output["index_chars"], index_labels_standard_chars, id_standard_chars, index_labels_multi_chars, id_multi_chars)
                taski_loss = 0.5 * (taski_loss_words + taski_loss_chars)

                preds = output["logits"]
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                # B, T, C -> T, B, C
                preds_log_softmax = preds.log_softmax(2).permute(1, 0, 2)
                loss_clf_words = self.criterion(preds_log_softmax, labels_index, preds_size, labels_length)

                preds_chars = output["logits_chars"]
                preds_size_chars = torch.IntTensor([preds_chars.size(1)] * batch_size)
                preds_log_softmax_chars = preds_chars.log_softmax(2).permute(1, 0, 2)
                loss_clf_chars = self.criterion(preds_log_softmax_chars, labels_index, preds_size_chars, labels_length)

                loss_clf = 0.5 * (loss_clf_words + loss_clf_chars)
            else:
                # The decoder receives labels_index[:, :-1] (aligned with Attention.forward).
                output = self.model(image, cross=True, text=labels_index[:, :-1], is_train=True)

                taski_loss_words = self.get_taski_loss_words(output["index"], index_labels_standard_words, id_standard_words, index_labels_multi_words, id_multi_words)
                taski_loss_chars = self.get_taski_loss_chars(output["index_chars"], index_labels_standard_chars, id_standard_chars, index_labels_multi_chars, id_multi_chars)
                taski_loss = 0.5 * (taski_loss_words + taski_loss_chars)

                preds = output["logits"]
                preds_chars = output["logits_chars"]
                target = labels_index[:, 1:]  # without [SOS] Symbol
                loss_clf_words = self.criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))
                loss_clf_chars = self.criterion(preds_chars.view(-1, preds_chars.shape[-1]), target.contiguous().view(-1))
                loss_clf = 0.5 * (loss_clf_words + loss_clf_chars)

            loss = pi * loss_clf + taski_loss
            self.model.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.opt.grad_clip)
            self.optimizer.step()
            train_loss_avg.add(loss_clf)
            train_taski_loss_avg.add(taski_loss)

            self.scheduler.step()
            if iteration % self.opt.val_interval == 0 or iteration == last_iter or iteration == 1:
                self.val(valid_loader, self.opt, start_time, iteration, train_loss_avg, train_taski_loss_avg, taski, step=1, val_choose="TF")
                train_loss_avg.reset()
                train_taski_loss_avg.reset()


    def val(self, valid_loader, opt, start_time, iteration, train_loss_avg, train_taski_loss_avg, taski, step, val_choose="val"):
        """Run validation, save the best-scoring checkpoint and log the result.

        Args:
            valid_loader: validation data passed to `validation`.
            opt: options; lan_list and exp_name form the checkpoint path.
            start_time: ignored — timing restarts here, so the logged
                per-sample time covers validation only.
            iteration: current training iteration (for the log).
            train_loss_avg: running recognition loss.
            train_taski_loss_avg: running task-index loss, or None.
            taski, step: task index and step, used in the checkpoint file name.
            val_choose: forwarded to `validation`.

        The best score is tracked on the validation set only (never on test
        sets, which would make comparisons unfair).
        """
        # Remember every submodule's mode and restore it afterwards. A blanket
        # `self.model.train()` would undo `_train`'s `model[-1].eval()` on the
        # frozen current branch at the first step-1 validation, letting its
        # normalization statistics drift.
        previous_modes = [(module, module.training) for module in self.model.modules()]
        self.model.eval()
        start_time = time.time()
        try:
            with torch.no_grad():
                (
                    valid_loss,
                    current_score,
                    ned_score,
                    preds,
                    confidence_score,
                    labels,
                    infer_time,
                    length_of_data,
                ) = validation(self.model, self.criterion, valid_loader, self.converter, opt, val_choose=val_choose)
        finally:
            for module, was_training in previous_modes:
                module.training = was_training

        if current_score > self.best_score:
            self.best_score = current_score
            name = opt.lan_list[taski]
            torch.save(self.model.state_dict(), f"./saved_models/{opt.exp_name}/{taski}_{name}_{step}_best_score.pth")
        self.show_valid_log(opt, iteration, start_time, train_loss_avg, valid_loss, train_taski_loss_avg, current_score, ned_score, infer_time, length_of_data)


    def test(self, AlignCollate_valid, valid_dataset_paths, best_scores, ned_scores, taski, forget17, forget19, total_scores17_each_task, total_scores19_each_task, val_choose="test"):
        """Evaluate the best checkpoint of task `taski` on every test dataset path.

        Loads the step-0 ("FF") checkpoint for task 0, or the step-1 ("TF")
        checkpoint otherwise; the `val_choose` argument is overridden
        accordingly. Accuracy and NED are collected per dataset.

        If there are exactly 2 * (taski + 1) datasets, scores are aggregated by
        `double_write` into the "17"/"19" groups and appended to best_scores /
        ned_scores / forget17 / forget19. Otherwise the plain averages are
        appended to best_scores and ned_scores.

        Returns the (updated) best_scores, ned_scores, forget17, forget19,
        total_scores17_each_task, total_scores19_each_task.
        """
        print("---Start evaluation on benchmark testset----")
        val_choose, step = ("FF", 0) if taski == 0 else ("TF", 1)
        name = self.opt.lan_list[taski]
        checkpoint_path = f"./saved_models/{self.opt.exp_name}/{taski}_{name}_{step}_best_score.pth"
        self.model.load_state_dict(torch.load(checkpoint_path), strict=True)

        task_accs = []
        ned_accs = []
        for val_data in valid_dataset_paths:
            valid_dataset, _ = hierarchical_dataset(root=val_data, opt=self.opt, mode="test")
            valid_loader = torch.utils.data.DataLoader(
                valid_dataset,
                batch_size=self.opt.batch_size,
                shuffle=True,
                num_workers=int(self.opt.workers),
                collate_fn=AlignCollate_valid,
                pin_memory=False,
            )

            self.model.eval()
            with torch.no_grad():
                (
                    _valid_loss,
                    current_score,
                    ned_score,
                    _preds,
                    _confidence_score,
                    _labels,
                    _infer_time,
                    _length_of_data,
                ) = validation(self.model, self.criterion, valid_loader, self.converter, self.opt, val_choose=val_choose)

            task_accs.append(round(current_score, 2))
            ned_accs.append(round(ned_score, 2))

        self.write_log(f"----------- {self.opt.exp_name} Task {taski}------------\n")
        if len(task_accs) != (taski + 1) * 2:
            best_scores.append(round(sum(task_accs) / len(task_accs), 2))
            ned_scores.append(round(sum(ned_accs) / len(ned_accs), 2))
            # Built but currently neither printed nor logged.
            acc_log = f'Task {taski} Test Average Incremental Accuracy: {best_scores[taski]} \n Task {taski} Incremental Accuracy: {task_accs}\n ned_acc: {ned_accs}\n'
        else:
            (
                avg_score17,
                avg_score19,
                avg_fg17,
                avg_fg19,
                total_scores17_each_task,
                total_scores19_each_task,
            ) = self.double_write(taski, task_accs, total_scores17_each_task, total_scores19_each_task)
            best_scores.append(avg_score17)
            ned_scores.append(avg_score19)
            forget17.append(avg_fg17)
            forget19.append(avg_fg19)
        print('--' * 80)
        return best_scores, ned_scores, forget17, forget19, total_scores17_each_task, total_scores19_each_task

    def show_valid_log(self, opt, iteration, start_time, train_loss_avg, valid_loss, train_taski_loss_avg, current_score, ned_score, infer_time, length_of_data):
        """Print one validation summary line and append it to the log.

        The line contains the losses, scores, the first param group's learning
        rate, the inference time, and the elapsed time since `start_time` per
        sample in milliseconds (divided by `length_of_data`). The task-index
        loss is included only when `train_taski_loss_avg` is not None.
        """
        lr = self.optimizer.param_groups[0]["lr"]
        elapsed_time = time.time() - start_time
        current_time = time.strftime('%Y-%m-%d %H:%M:%S')
        parts = [
            f"{current_time} [{iteration}/{opt.num_iter}] Train_clf_loss: {train_loss_avg.val():0.5f}, Valid_loss: {valid_loss:0.5f}, "
        ]
        if train_taski_loss_avg is not None:
            parts.append(f'Train_taski_loss: {train_taski_loss_avg.val():0.5f}, ')
        parts.append(f'Current_score: {current_score:0.2f}, Ned_score: {ned_score:0.2f}, Best_score: {self.best_score:0.2f}, ')
        ms_per_sample = elapsed_time / length_of_data * 1000
        parts.append(f'Current_lr: {lr:0.7f}, Infer_time: {infer_time:0.2f}, Elapsed_time: {ms_per_sample:0.2f} ')
        valid_log = "".join(parts)
        print(valid_log)
        self.write_log(valid_log + "\n")